from .bce_loss import bce_loss
from .ce_loss import ce_loss
from .dice_loss import dice_loss
from .focal_loss import focal_loss


# Registry mapping a loss's string name to the object imported above.
# loss_function() looks names up here, so a new loss must be added both as an
# import and as an entry in this mapping.
loss_ = dict(
    bce_loss=bce_loss,
    ce_loss=ce_loss,
    dice_loss=dice_loss,
    focal_loss=focal_loss,
)


def loss_function(name, **kwargs):
    """Look up a loss in the ``loss_`` registry by name.

    Args:
        name: A key of ``loss_``: ``"bce_loss"``, ``"ce_loss"``,
            ``"dice_loss"`` or ``"focal_loss"``.
        **kwargs: Forwarded only when ``name == "focal_loss"``. For every
            other name they are ignored.

    Returns:
        For ``"focal_loss"``, the result of calling the registered object
        with ``**kwargs``. This is presumably a configured loss callable;
        confirm in ``focal_loss.py``. For any other name, the registered
        object itself, uncalled.

    Raises:
        ValueError: If ``name`` is not a key of ``loss_``.
    """
    if name not in loss_:
        # Explicit raise rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let an unknown name fall through to a
        # bare KeyError with no hint about the valid choices.
        available = ", ".join(sorted(loss_))
        raise ValueError(f"Unknown loss {name!r}; expected one of: {available}")

    loss = loss_[name]
    if name == "focal_loss":
        # Only focal_loss receives the keyword arguments; it is called here
        # instead of being returned directly.
        return loss(**kwargs)
    # NOTE(review): kwargs are silently ignored for the other losses.
    return loss
